Conversation
Summary of Changes
Hello @Artiprocher, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for FP8 (8-bit floating point) LoRA training for Qwen-Image models. This enhancement aims to significantly reduce memory consumption and potentially accelerate the training process by leveraging lower precision data types. The changes involve modifying the Qwen-Image pipeline to handle FP8 data, updating the LoRA utility functions to support mixed-precision training, and integrating a new command-line option to enable this feature in the training examples.
Highlights
- FP8 LoRA Training Support: Introduces comprehensive support for FP8 (8-bit floating point) LoRA training for Qwen-Image models, enabling more memory-efficient and potentially faster training.
- Pipeline-level FP8 Configuration: Adds a new internal method `_enable_fp8_lora_training` within the `QwenImagePipeline` to configure and manage VRAM for FP8 operations across the text encoder, DiT, and VAE components.
- LoRA Parameter Upcasting: Enhances the `add_lora_to_model` utility function to allow upcasting of LoRA parameters to a higher precision (e.g., bfloat16) during mixed-precision training, ensuring numerical stability.
- Command-line Activation: Exposes a new command-line argument `--enable_fp8_training` in the training script, providing a straightforward way to activate this new feature.
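As a rough illustration of the last highlight, the new flag might be wired into the training script's argument parser as sketched below. Only the flag name `--enable_fp8_training` comes from this PR; the surrounding parser setup is hypothetical.

```python
import argparse

# Hypothetical sketch of the training script's CLI; only the flag name
# --enable_fp8_training is taken from the PR description.
parser = argparse.ArgumentParser(description="Qwen-Image LoRA training (sketch)")
parser.add_argument(
    "--enable_fp8_training",
    default=False,
    action="store_true",
    help="Store base model weights in FP8 to reduce VRAM; LoRA parameters "
         "are kept or upcast in a higher precision (e.g., bfloat16).",
)

args = parser.parse_args(["--enable_fp8_training"])
print(args.enable_fp8_training)  # True when the flag is passed
```

With `action="store_true"`, omitting the flag leaves the option at `False`, so existing training invocations are unaffected.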
Code Review
This pull request adds support for FP8 LoRA training for qwen-image models. The changes span the pipeline definition, trainer utilities, and the example training script. The implementation is mostly sound, but I have identified a couple of areas for improvement regarding device handling and input parsing robustness. My review includes suggestions to address these points.
```python
offload_dtype=dtype,
offload_device="cuda",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device="cuda",
```
The device is hardcoded to "cuda". It's better to use self.device to be consistent with the rest of the class and allow users to specify a different device (e.g., cuda:1) during pipeline initialization.
```diff
 offload_dtype=dtype,
-offload_device="cuda",
+offload_device=self.device,
 onload_dtype=dtype,
-onload_device="cuda",
+onload_device=self.device,
 computation_dtype=self.torch_dtype,
-computation_device="cuda",
+computation_device=self.device,
```
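The point of the suggestion can be seen in a minimal, hypothetical pipeline sketch (class and method names below are illustrative, not from the PR): a device chosen at initialization, such as `cuda:1`, only propagates to the VRAM-management config if the instance attribute is used instead of a literal string.

```python
# Hypothetical minimal pipeline illustrating why self.device is preferable
# to a hardcoded "cuda" string in the VRAM-management configuration.
class PipelineSketch:
    def __init__(self, device="cuda", torch_dtype="bfloat16"):
        self.device = device
        self.torch_dtype = torch_dtype

    def vram_config(self, dtype):
        # Mirrors the keyword arguments in the diff above, with every
        # device taken from the instance instead of a literal "cuda".
        return dict(
            offload_dtype=dtype,
            offload_device=self.device,
            onload_dtype=dtype,
            onload_device=self.device,
            computation_dtype=self.torch_dtype,
            computation_device=self.device,
        )

# A user who initializes the pipeline on a secondary GPU gets that
# device everywhere, with no stray allocations on cuda:0.
pipe = PipelineSketch(device="cuda:1")
config = pipe.vram_config("float8_e4m3fn")
print(config["computation_device"])  # cuda:1
```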
```diff
 if model_paths is not None:
     model_paths = json.loads(model_paths)
-    model_configs += [ModelConfig(path=path) for path in model_paths]
+    model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
 if model_id_with_origin_paths is not None:
     model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
+    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
```
The logic for parsing model_id_with_origin_paths is not robust. Using item.split(':', 1) is safer than item.split(':') as it prevents errors if the origin_file_pattern contains colons (e.g., in a Windows path). Additionally, explicitly checking that the split results in two parts before unpacking will prevent IndexError if an entry is malformed.
This refactoring improves robustness and readability.
```diff
 if model_paths is not None:
     model_paths = json.loads(model_paths)
-    model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+    model_configs.extend(ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths)
 if model_id_with_origin_paths is not None:
     model_id_with_origin_paths = model_id_with_origin_paths.split(",")
-    model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+    for item in model_id_with_origin_paths:
+        parts = item.split(":", 1)
+        if len(parts) == 2:
+            model_configs.append(ModelConfig(model_id=parts[0], origin_file_pattern=parts[1], offload_dtype=offload_dtype))
```
support qwen-image fp8 lora training
No description provided.